//-*-C++-*-
/***************************************************************************
 *
 *   Copyright (C) 2025 by Will Gauvin
 *   Licensed under the Academic Free License version 2.1
 *
 ***************************************************************************/

#ifndef __dsp_RescaleMedianMadCalculator_h
#define __dsp_RescaleMedianMadCalculator_h

#include "dsp/Rescale.h"

#include <cstdint>
#include <vector>

namespace dsp
{
  /**
   * @brief a class used to calculate the median and median absolute deviation (MAD) of sample data
   *
   * Unlike the RescaleMeanStdCalculator, which only keeps a running total of the sum and sum of
   * squares, this class must store the sampled data until compute() is called, because a median
   * (and hence the MAD) cannot be accumulated incrementally.
   *
   * The median, and then the MAD, are found with a selection algorithm that uses the
   * median-of-medians pivot strategy. With that pivot, selection runs in O(n) time even in the
   * worst case (plain quickselect is O(n) on average but O(n^2) in the worst case).
   * NOTE(review): this bound relies on pivot() returning a true median of medians; confirm in
   * the implementation.
   *
   * @see https://en.wikipedia.org/wiki/Median_of_medians#Algorithm for details of the algorithm.
   */
  class RescaleMedianMadCalculator : public Rescale::ScaleOffsetCalculator {

  public:
    //! Default for scale_factor. 0.6744898 is the 0.75 quantile of the standard normal
    //! distribution, i.e. for Gaussian data MAD = 0.6744898 * standard deviation.
    static constexpr float DEFAULT_SCALE_FACTOR = 0.6744898;

    //! Default constructor
    RescaleMedianMadCalculator() = default;

    //! Default destructor
    virtual ~RescaleMedianMadCalculator() = default;

    /**
     * @brief initialise the calculator to allow for allocation of arrays and/or scratch space
     *
     * @param input the input timeseries of the Rescale operation, needed to get correct nchan, npol and ndim
     * @param nsample the number of samples to be used to calculate the scales and offsets
     * @param output_time_total flag that instructs the Calculator to compute an F-scrunched time series of rescaled data
     */
    void init(const dsp::TimeSeries* input, uint64_t nsample, bool output_time_total) override;

    /**
     * @brief sample the current input timeseries to allow calculation of scales and offsets
     *
     * This method stores all of the samples so that they can later be used to calculate the median.
     *
     * @param input the input timeseries to sample the data from; it may be either FPT or TFP ordered data.
     * @param start_dat the first time sample to use (not necessarily 0)
     * @param end_dat the end time sample; not necessarily input->get_ndat(), it may be less than that.
     * @param output_time_total an indicator of whether to accumulate the time sample values.
     * @returns the ending time sample (should be end_dat)
     */
    uint64_t sample_data(const dsp::TimeSeries* input, uint64_t start_dat, uint64_t end_dat, bool output_time_total) override;

    /**
     * @brief computes the scale and offset for each channel and polarisation of the sample data.
     *
     * @param nsample the total number of samples that Rescale has used to sample data.
     */
    void compute(uint64_t nsample) override;

    /**
     * @brief resets the sampled data
     */
    void reset_sample_data() override;

    /**
     * @brief get a pointer to the per-channel "mean" values for the given polarisation
     *
     * NOTE(review): this calculator computes a median rather than an arithmetic mean, so these
     * values are presumably the per-channel medians; confirm against compute().
     *
     * @param ipol the polarisation index
     */
    const double* get_mean(unsigned ipol) const override;

    /**
     * @brief get a pointer to the per-channel "variance" values for the given polarisation
     *
     * NOTE(review): presumably derived from the MAD and scale_factor rather than a sample
     * variance; confirm against compute().
     *
     * @param ipol the polarisation index
     */
    const double* get_variance(unsigned ipol) const override;

    /**
     * @brief get the scale factor used to convert the MAD to an estimate of the standard deviation
     */
    float get_scale_factor() const { return scale_factor; }

    /**
     * @brief set the scale factor used to convert the MAD to an estimate of the standard deviation
     *
     * @param _scale_factor the new scale factor
     * @throws Error if _scale_factor is zero.
     */
    void set_scale_factor(float _scale_factor);

  private:
    /**
     * @brief find the median of the data.
     *
     * This does not return the true median when n is even: in that case it returns the
     * element at 0-offset index (n - 1) / 2 (the lower of the two middle values).
     *
     * @param data the data to find the median of; may be reordered by the selection
     * @param n the length of the data
     */
    float find_median(float *data, uint64_t n);

    /**
     * @brief select the n'th smallest value of the data (0-offset)
     *
     * Assertion: left <= n <= right
     *
     * This method loops through the following steps until either left == right
     * or a partition index equals the desired n:
     *
     * a) if left == right, return left (only one element remains; by the assertion it is index n)
     * b) find the index of the median of medians @see pivot
     * c) partition the data about that pivot to get a bound index @see partition
     * d) if the partition index is n, return n
     *    else if n < partition index
     *      narrow the search to the lower part (move right down)
     *    else
     *      narrow the search to the upper part (move left up)
     *
     * @param data the data used for finding the median; it is reordered in place
     * @param left the lower bound index of the data that may need to be sorted
     * @param right the upper bound index of the data that may need to be sorted
     * @param n the index, within data, of the order statistic to select
     *
     * @returns either left or n depending on the exit condition of the loop.
     */
    uint64_t select(float *data, uint64_t left, uint64_t right, uint64_t n);

    /**
     * @brief pivot the data to find the median of medians.
     *
     * This is the actual median-of-medians algorithm. It divides the data between
     * indices left and right (inclusive) into groups of at most 5 elements (the last
     * group may have fewer than 5), and finds the median of each group by
     * calling partition5. It then finds the median of those group medians.
     *
     * @param data the data used for finding the median
     * @param left the lower bound index to perform the operation on
     * @param right the upper bound index to perform the operation on
     *
     * @return the index of the pivot (the median of medians)
     */
    uint64_t pivot(float *data, uint64_t left, uint64_t right);

    /**
     * @brief partitions the data into a 3-way partition.
     *
     * This call partitions the data between left and right into 3 groups such that:
     *
     *  a) the first group holds values less than the pivot value (i.e. the value at pivot_idx).
     *  b) the second group holds values all equal to the pivot value.
     *  c) the third group holds values all greater than the pivot value.
     *
     * The returned index depends on which group index n falls in:
     *
     *  a) if n is in the first group, the returned index is the 'right' bound of this group. Values at any
     *    index higher than this will be in the 2nd or 3rd group.
     *  b) if n is in the second group, return n (i.e. the desired value has been found).
     *  c) otherwise, the returned index is the 'left' bound of the 3rd group. Values at any index lower than
     *    this will be in either the 1st or 2nd group.
     *
     * @param data the data used for finding the median; it is reordered in place
     * @param left the lower bound index to perform the operation on
     * @param right the upper bound index to perform the operation on
     * @param pivot_idx the index, within [left, right], of the value to partition about
     * @param n the index of the order statistic being selected (see select)
     *
     * @return the bound index based on which partition group index n belongs to.
     */
    uint64_t partition(float *data, uint64_t left, uint64_t right, uint64_t pivot_idx, uint64_t n);

    /**
     * @brief performs a partitioning of data when at most 5 values are used.
     *
     * This is an optimised method that partitions the data such that the
     * median value is at the middle index of the range. All values between left and that
     * index are less than or equal to the median, and all values between that index and
     * right are greater than or equal to it.
     *
     * @param data the data used for finding the median
     * @param left the lower bound index to perform the operation on
     * @param right the upper bound index to perform the operation on
     *
     * @returns the median index, i.e. the middle of the range, left + (right - left) / 2
     *   NOTE(review): the original doc said "(right - left) / 2", which is an offset
     *   rather than an index; confirm against the implementation.
     */
    uint64_t partition5(float *data, uint64_t left, uint64_t right);

    //! the sampled data, ordered by [pol][chan][sample]
    std::vector<std::vector<std::vector<float>>> _sample_data;

    //! the absolute deviation of each sample from the median, ordered by [pol][chan][sample]
    std::vector<std::vector<std::vector<float>>> absolute_deviation;

    //! the values returned by get_mean, ordered by [pol][chan]
    std::vector<std::vector<double>> mean;

    //! the values returned by get_variance, ordered by [pol][chan]
    std::vector<std::vector<double>> variance;

    //! scale factor used to convert the MAD into an estimate of the standard deviation
    float scale_factor{DEFAULT_SCALE_FACTOR};
  };

} // namespace dsp

#endif // __dsp_RescaleMedianMadCalculator_h
